"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn
from util import compute_gram_matrix
import torch.nn.functional as F


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07, augmented_features=True,
                 decor_reg=0.1, decor_loss=True,
                 ent_reg=0.1, ent_loss=True,
                 spec_reg=0.1, spec_loss=True):
        """Store the contrastive hyper-parameters and the regularizer settings.

        Args:
            temperature: divisor applied to the anchor/contrast dot products.
            contrast_mode: 'all' uses every view as an anchor, 'one' uses only
                the first view.
            base_temperature: the final loss is scaled by
                ``temperature / base_temperature``.
            augmented_features: when true the regularizers stack all views
                into one feature matrix, otherwise they use view 0 only.
            decor_reg, ent_reg, spec_reg: weights of the decorrelation,
                entropy and spectrum regularizers in the total loss.
            decor_loss, ent_loss, spec_loss: switches enabling each regularizer.
        """
        super(SupConLoss, self).__init__()

        # Contrastive term.
        self.temperature, self.base_temperature = temperature, base_temperature
        self.contrast_mode = contrast_mode

        # Which rows the regularizers see.
        self.augmented_features = augmented_features

        # (weight, enabled) pair for each regularizer.
        self.decor_reg, self.decor_loss = decor_reg, decor_loss
        self.ent_reg, self.ent_loss = ent_reg, ent_loss
        self.spec_reg, self.spec_loss = spec_reg, spec_loss

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)
        # features = features[:, :, 0:200]
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point. 
        # Edge case e.g.:- 
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan] 
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        if self.decor_loss == True: 
            decor_loss = self.decorrelation_loss(features)
        else: 
            decor_loss = 0

        if self.ent_loss == True:
            entropy_loss = self.entropy_loss(features)
        else:
            entropy_loss = 0 

        if self.spec_loss == True:
            spectral_flattening_loss = self.flatten_spectrum_feature_loss(features) # trial 12 to 16 -> spectrum_loss = torch.mean((S_norm - 1) ** 2)
            # spectral_flattening_loss = self.flatten_spectrum_feature_loss(features) # trial 7 to 11 -> spectrum_loss = torch.var(S_norm)
            # spectral_flattening_loss = self.iso_logdet(features)
            # spectral_flattening_loss = self.iso_stable_rank(features) # trial 2 to 6
            # spectral_flattening_loss = self.spectral_flattening_regularizer(features) 

        else: 
            spectral_flattening_loss = 0

        total_loss = loss + self.decor_reg * decor_loss + self.ent_reg * entropy_loss + self.spec_reg * spectral_flattening_loss

        return total_loss, torch.tensor([decor_loss]), torch.tensor([entropy_loss]), torch.tensor([spectral_flattening_loss])


    def decorrelation_loss(self, features, mode='full'):
        """Compute decorrelation loss."""
        # Flatten features across all views
        bsz, n_views, dim = features.shape
        if self.augmented_features == True:
            features = features.view(bsz * n_views, dim)
        else: 
            features = features[:, 0, :]
        # Compute the covariance matrix
        features = features - features.mean(dim=0, keepdim=True)  # Center features

        # Compute covariance
        cov = (features.T @ features) / (features.shape[0] - 1)  # (D, D)

        # Identity target
        I = torch.eye(dim, device=features.device)

        if mode == 'full':
            # Frobenius norm of Cov - I
            decor_loss = torch.norm(cov - I, p='fro') / dim  # normalized
        elif mode == 'offdiag':
            off_diag = cov - torch.diag(torch.diag(cov))
            decor_loss = (off_diag ** 2).sum() / dim  # normalized
        else:
            raise ValueError(f"Unknown mode: {mode}")

        return decor_loss

    def entropy_loss(self, features):
        # perform SVD to get singular values
        bsz, n_views, dim = features.shape
        if self.augmented_features == True: 
            features = features.view(bsz * n_views, dim)
        else: 
            features = features[:, 0, :]
        U, S, V = torch.svd(features)
        # print(S.size())
        # print(S)

        # normalize singular values to form a probability distribution
        S_normalized = S / torch.sum(S)
    
        # calculate entropy
        entropy = -torch.sum(S_normalized * torch.log(S_normalized + 1e-6))
        return -entropy

    def flatten_spectrum_feature_loss(self, features):
        """
        Spectrum flattening regularization loss.
        Encourages singular values to be more evenly spread.

        Args:
            features (Tensor): shape (bsz, n_views, dim) or (bsz, 1, dim)
            augmented_features (bool): whether features have multiple views stacked
        Returns:
            loss (Tensor)
        """

        bsz, n_views, dim = features.shape

        if self.augmented_features:
            features = features.view(bsz * n_views, dim)
        else:
            features = features[:, 0, :]

        F_centered = features - features.mean(dim=0, keepdim=True)

        # Compute covariance matrix (d x d)
        # C = F_centered.T @ F_centered / F_centered.shape[0]-1  # (d x d)

        # Eigen decomposition
        # eigvals, eigvecs = torch.linalg.eigh(F_centered)  # eigvals ascending order
        # Compute SVD
        U, S, Vh = torch.linalg.svd(F_centered, full_matrices=False)
        # eigvals = torch.linalg.eigvalsh(F_centered)
        # S = torch.linalg.svdvals(F_centered)
        S_norm = S / S.sum()
        # S_norm = S / S.max()

        # Flatten the spectrum
        # spectrum_loss = - torch.sum(torch.log(S_norm.clamp(min=1e-6)))

        k = 5  # Number of smallest eigenvalues to target
        small_s_norm = S_norm[:k]
        inverse_smallest = 1.0 / (small_s_norm + 1e-6)
        spectrum_loss = torch.mean(inverse_smallest)
        # eigvals = S**2
        # eigvals = eigvals.clamp_min(1e-6)
        # p = eigvals / eigvals.sum()
        # kl_flatten = torch.sum(p * torch.log(p * eigvals.numel()))

        # spectrum_loss = torch.mean((S_norm - 1) ** 2) # waterbirds

        # spectrum_loss = torch.mean(S_norm ** 2)

        # spectrum_loss = torch.var(S_norm)

        # # Step 3: Select top-k and normalize
        # top_k = S[:5]
        # top_k_norm = top_k / S.max()

        # # Step 4: Take mean (or sum)
        # spectrum_loss = -torch.mean(top_k_norm)

        # Optionally scale (you can choose)
        # spectrum_loss = spectrum_loss / dim

        # Take top-k singular values
        # top_singular_values = S[:50]

        # Normalize (optional)
        # top_singular_values = top_singular_values / top_singular_values.max()

        # Compute variance
        # mean_top = top_singular_values.mean()
        # spectrum_loss = ((top_singular_values - mean_top) ** 2).mean()

        return spectrum_loss

    def spectral_flattening_regularizer(self, features):
        """
        Computes the spectral flattening regularizer for a batch of features.

        Args:
            features (torch.Tensor): Tensor of shape (batch_size, feature_dim),
                                    assumed to be the output of the encoder.

        Returns:
            torch.Tensor: Scalar regularization loss.
        """
        bsz, n_views, dim = features.shape

        if self.augmented_features:
            features = features.view(bsz * n_views, dim)
        else:
            features = features[:, 0, :]

        # Center the features (important for correct Gram matrix)
        features = features - features.mean(dim=0, keepdim=True)

        # Compute empirical Gram matrix G
        G = features.T @ features / features.size(0)  # Shape: (feature_dim, feature_dim)

        # Frobenius norm squared
        frob_norm_squared = torch.sum(G ** 2)

        # Trace squared
        trace_squared = torch.sum(torch.diag(G)) ** 2

        # Spectral flattening regularizer
        reg_loss = frob_norm_squared - trace_squared

        return reg_loss


    # ===== 1) log‑det (entropy) maximisation =================
    def iso_logdet(self, features, augmented_features=True, eps=1e-4):

        bsz, n_views, dim = features.shape

        if self.augmented_features:
            features = features.view(bsz * n_views, dim)
        else:
            features = features[:, 0, :]
        
        features = features.T 
        features = features - features.mean(dim=1, keepdim=True).detach()

        """Entropy‑maximisation regulariser  -log det(C + eps I)."""
        # empirical covariance  C = 1/B * F Fᵀ
        C = (features @ features.T) / features.shape[1]          # (d, d)
        # slogdet returns (sign, logabsdet); sign is always +1 for psd
        logdet = torch.linalg.slogdet(C + eps*torch.eye(dim, device=C.device))[1]

        return -logdet  # maximise entropy → minimise -logdet

    # ===== 2) stable‑rank ratio  (sigma_max² / Frobenius²) ===
    def iso_stable_rank(self, features, augmented_features=True, n_power_iter=5):

        bsz, n_views, dim = features.shape

        if self.augmented_features:
            features = features.view(bsz * n_views, dim)
        else:
            features = features[:, 0, :]

        features = features.T 

        """Stable‑rank inverse penalty: sigma_max² / Frobenius²."""
        # Frobenius norm squared
        frob2 = features.pow(2).sum()

        # estimate top singular value with power iteration
        v = torch.randn(features.shape[1], device=features.device)
        v = v / v.norm()
        for _ in range(n_power_iter):
            v = (features.T @ (features @ v))        # (B,) vector
            v = v / v.norm()
        sigma_max = (features @ v).norm()        # spectral norm

        return (sigma_max ** 2) / frob2

    def flatten_spectrum_feature_loss_old(self, features, alpha=0.1):
        
        bsz, n_views, dim = features.shape
        if self.augmented_features == True: 
            features = features.view(bsz * n_views, dim)
        else: 
            features = features[:, 0, :]

        U, S, V = torch.svd(features)
        
        # S_normalized = S / S.max()
        
        # spectrum_loss = torch.var(S_normalized)
        # spectrum_loss = S_normalized.max() / (S_normalized.min() + 1e-8)
        # spectrum_loss = torch.mean((S_normalized - 1) ** 2)

        

        # C = features.T @ features  # (embed_dim, embed_dim)
        # C_norm = C / (bsz - 1.0)  # optional normalization

        # off_diag = (C_norm - torch.diag(torch.diag(C_norm)))**2
        # diag = (torch.diag(C_norm) - 1)**2

        # loss_off_diag = off_diag.sum()   # penalize correlation
        # loss_diag = diag.sum()           # penalize deviation from unit variance

        # spectrum_loss = loss_off_diag + alpha * loss_diag

        spectrum_loss = - torch.sum(torch.log(S + 1e-12))

        return spectrum_loss

    def flatten_spectrum_grammatrix_loss(self, features, k=5, epsilon=1e-8):
        """
        Spectrum flattening regularizer on the Gram matrix of the features.

        The singular values of the Gram matrix are divided by the largest one,
        and the loss is the mean inverse of the `k` smallest normalised values.
        This is the ratio largest / smallest singular value, averaged over the
        k smallest, so minimising it pushes the spectrum towards being flat.

        Args:
            features (torch.Tensor): Tensor of shape (bsz, n_views, dim).
            k (int): number of smallest singular values to target (>= 1).
            epsilon (float): added before inversion to avoid division by zero.

        Returns:
            torch.Tensor: The regularization loss value.
        """
        bsz, n_views, dim = features.shape
        if self.augmented_features:
            # reshape (not view) so non-contiguous inputs are accepted.
            features = features.reshape(bsz * n_views, dim)
        else:
            features = features[:, 0, :]

        # compute_gram_matrix is imported from `util` and not visible here.
        # NOTE(review): if the Gram matrix it returns is rank-deficient (for
        # example samples x samples with more samples than feature dims), the
        # smallest singular values are ~0 and this loss saturates near
        # 1 / epsilon - confirm against the helper.
        gram_matrix = compute_gram_matrix(features)

        # Compute singular values of the Gram matrix, normalised by the largest.
        singular_values = torch.linalg.svdvals(gram_matrix)
        singular_values = singular_values / singular_values.max()

        # svdvals returns values in DESCENDING order, so the smallest ones are
        # at the tail. Slicing the head would pick the largest instead.
        smallest_singular_values = singular_values[-k:]

        # Regularization loss: penalize large inverses.
        inverse_smallest = 1.0 / (smallest_singular_values + epsilon)
        spectrum_loss = torch.mean(inverse_smallest)

        return spectrum_loss

class SimSiamLoss(nn.Module):
    def __init__(self, augmented_features=True, decor_reg=0.1, decor_loss=True, 
                 ent_reg=0.1, ent_loss=True, spec_reg=0.1, spec_loss=True):
        """Store the regularizer settings used on top of the SimSiam loss.

        Args:
            augmented_features: when true the decorrelation and entropy
                regularizers stack all views into one feature matrix,
                otherwise they use view 0 only.
            decor_reg, ent_reg, spec_reg: weights of the decorrelation,
                entropy and spectrum regularizers in the total loss.
            decor_loss, ent_loss, spec_loss: switches enabling each regularizer.
        """
        super(SimSiamLoss, self).__init__()

        # Which rows the regularizers see.
        self.augmented_features = augmented_features

        # (weight, enabled) pair for each regularizer.
        self.decor_reg, self.decor_loss = decor_reg, decor_loss
        self.ent_reg, self.ent_loss = ent_reg, ent_loss
        self.spec_reg, self.spec_loss = spec_reg, spec_loss
    
    def forward(self, h1, p1, p2, z1, z2):
        # Stop gradients for z2
        # z1 = z1.detach()
        # z2 = z2.detach()
        
        # Negative cosine similarity
        loss1 = -torch.mean(nn.functional.cosine_similarity(p1, z2.detach(), dim=-1))
        loss2 = -torch.mean(nn.functional.cosine_similarity(p2, z1.detach(), dim=-1))

        loss = 0.5 * (loss1 + loss2)

        if self.decor_loss == True: 
            decor_loss_1, decor_loss_2 = self.decorrelation_loss(h1)
        else: 
            decor_loss_1, decor_loss_2 = 0, 0

        if self.ent_loss == True:
            entropy_loss = self.entropy_loss(h1)
        else:
            entropy_loss = 0 

        if self.spec_loss == True:
            spectral_flattening_loss = self.flatten_spectrum_feature_loss(h1)
            # spectral_flattening_loss = 0.5 * (spectral_flattening_h1 + spectral_flattening_h2)
        else: 
            spectral_flattening_loss = 0

        total_loss = loss + self.decor_reg * decor_loss_1 + self.decor_reg * decor_loss_2 + self.ent_reg * entropy_loss + self.spec_reg * spectral_flattening_loss
        
        return total_loss, torch.tensor([decor_loss_1+decor_loss_2]), torch.tensor([entropy_loss]), torch.tensor([spectral_flattening_loss])

    def decorrelation_loss(self, features):
        """Compute decorrelation loss."""
        # Flatten features across all views
        bsz, n_views, dim = features.shape
        if self.augmented_features == True:
            features = features.view(bsz * n_views, dim)
        else: 
            features = features[:, 0, :]
        # Compute the covariance matrix
        features = features - features.mean(dim=0, keepdim=True)  # Center features

        # Compute the Gram matrix (feature covariance)
        cov = (features.T @ features) / (features.shape[0] - 1)

        # Identity matrix for the same dimension as the Gram matrix
        identity_matrix = torch.eye(cov.size(0), device=features.device)

        # Frobenius norm of the difference between Gram matrix and Identity matrix
        decor_loss_1 = torch.norm(cov - identity_matrix, p='fro')
        
        return decor_loss_1, 0

    def entropy_loss(self, features):
        # perform SVD to get singular values
        bsz, n_views, dim = features.shape
        if self.augmented_features == True: 
            features = features.view(bsz * n_views, dim)
        else: 
            features = features[:, 0, :]
        U, S, V = torch.svd(features)

        # normalize singular values to form a probability distribution
        S_normalized = S / torch.sum(S)
    
        # calculate entropy
        entropy = -torch.sum(S_normalized * torch.log(S_normalized + 1e-6))
        return -entropy

    def flatten_spectrum_feature_loss(self, features):

        F_centered = features - features.mean(dim=0, keepdim=True)

        # Compute SVD
        U, S, Vh = torch.linalg.svd(F_centered, full_matrices=False)
        # S_norm = S / S.sum()
        S_norm = S / S.max()

        # k = 5  # Number of smallest eigenvalues to target
        # large_s_norm = S_norm[:k]
        # inverse_largest = 1.0 / (large_s_norm + 1e-6)
        # spectrum_loss = torch.mean(inverse_largest)

        spectrum_loss = torch.mean((S_norm - 1) ** 2)

        return spectrum_loss

    def flatten_spectrum_grammatrix_loss(self, features, k=5, epsilon=1e-8):
        """
        Spectrum flattening regularizer on the Gram matrix of the features.

        The singular values of the Gram matrix are divided by the largest one,
        and the loss is the mean inverse of the `k` smallest normalised values.
        This is the ratio largest / smallest singular value, averaged over the
        k smallest, so minimising it pushes the spectrum towards being flat.

        Args:
            features (torch.Tensor): 2-D feature matrix, or a 3-D
                (bsz, n_views, dim) tensor. A 3-D input is flattened first:
                all views stacked as rows when `self.augmented_features` is
                truthy, otherwise view 0 only.
            k (int): number of smallest singular values to target (>= 1).
            epsilon (float): added before inversion to avoid division by zero.

        Returns:
            torch.Tensor: The regularization loss value.
        """
        if features.dim() == 3:
            bsz, n_views, dim = features.shape
            if self.augmented_features:
                features = features.reshape(bsz * n_views, dim)
            else:
                features = features[:, 0, :]

        # compute_gram_matrix is imported from `util` and not visible here.
        # NOTE(review): if the Gram matrix it returns is rank-deficient (for
        # example samples x samples with more samples than feature dims), the
        # smallest singular values are ~0 and this loss saturates near
        # 1 / epsilon - confirm against the helper.
        gram_matrix = compute_gram_matrix(features)

        singular_values = torch.linalg.svdvals(gram_matrix)
        singular_values = singular_values / singular_values.max()

        # svdvals returns values in DESCENDING order, so the smallest ones are
        # at the tail. Slicing the head would pick the largest instead.
        smallest_singular_values = singular_values[-k:]
        inverse_smallest = 1.0 / (smallest_singular_values + epsilon)
        spectrum_loss = torch.mean(inverse_smallest)

        return spectrum_loss

    # def spectrum_flattening_loss(self, features):
        
    #     bsz, n_views, dim = features.shape
    #     if self.augmented_features == True: 
    #         features = features.view(bsz * n_views, dim)
    #     else: 
    #         features = features[:, 0, :]

    #     U, S, V = torch.svd(features)
    #     S_normalized = S / torch.max(S)
        
    #     spectrum_loss = torch.var(S_normalized)

    #     return spectrum_loss

    # def flatten_spectrum_loss(self, features):
    #     """
    #     Compute the spectrum flattening regularizer for a matrix.

    #     This regularizer minimizes the ratio of the largest singular value to the smallest singular value
    #     to encourage all singular values to be similar.

    #     Args:
    #         matrix (torch.Tensor): Input matrix (e.g., the Gram matrix) of shape (n, n).

    #     Returns:
    #         torch.Tensor: The regularization loss value.
    #     """
    #     bsz, n_views, dim = features.shape
    #     if self.augmented_features == True: 
    #         features = features.view(bsz * n_views, dim)
    #     else: 
    #         features = features[:, 0, :]

    #     gram_matrix = compute_gram_matrix(features)

    #     # Compute singular values of the Gram matrix
    #     singular_values = torch.linalg.svdvals(gram_matrix)
    #     singular_values = singular_values / singular_values.max()
            
    #     # Avoid division by zero by adding a small epsilon
    #     epsilon = 1e-8

            
    #     # Compute the regularizer as the ratio of the largest to the smallest singular value
    #     spectrum_loss = singular_values.max() / (singular_values.min() + epsilon)
            
    #     return spectrum_loss

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat 1-D tensor.

    Entries are listed in row-major order with the diagonal skipped.
    """
    rows, cols = x.shape
    assert rows == cols, "Input must be a square matrix."
    # After dropping the last element, re-viewing the flat data as
    # (n - 1, n + 1) puts every remaining diagonal entry in column 0, so
    # removing that column leaves exactly the off-diagonal entries.
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

class BarlowTwinsLoss(nn.Module):
    """Barlow Twins loss on the cross-correlation matrix of two views.

    The on-diagonal entries are pulled towards 1 (invariance term) and the
    off-diagonal entries towards 0 (redundancy-reduction term, weighted by
    `lambda_param`).
    """

    def __init__(self, lambda_param=5e-3):
        """
        Args:
            lambda_param (float): weight of the off-diagonal term.
        """
        super(BarlowTwinsLoss, self).__init__()
        self.lambda_param = lambda_param

    def forward(self, z):
        """
        Compute the Barlow Twins loss.

        Args:
            z (torch.Tensor): Feature representations of shape 
                              [batch_size, n_views, feature_dim], where n_views should be 2.
                              
        Returns:
            tuple: ``(loss, zeros, zeros, zeros)``. `loss` is the scalar loss
            tensor; the other three entries are constant ``tensor([0])``
            placeholders (presumably so the return shape matches the other
            losses in this module - confirm against callers).
        """
        # Ensure two views are provided.
        assert z.dim() == 3 and z.shape[1] == 2, "Input tensor must have shape [B, 2, D]"

        # Split the two views.
        z1, z2 = z[:, 0], z[:, 1]  # Both have shape [B, D]

        # Standardise each feature along the batch dimension.
        z1_norm = (z1 - z1.mean(0)) / (z1.std(0) + 1e-9)
        z2_norm = (z2 - z2.mean(0)) / (z2.std(0) + 1e-9)

        # Cross-correlation matrix, shape [D, D]. `std(0)` is the unbiased
        # (N - 1) estimate, so the matching denominator is N - 1. Dividing by
        # N instead caps the diagonal at (N - 1) / N even for identical views,
        # so the invariance term could never reach zero.
        batch_size = z1.shape[0]
        c = torch.mm(z1_norm.T, z2_norm) / (batch_size - 1)

        # Loss: Invariance loss (on-diagonal) and redundancy reduction loss (off-diagonal).
        on_diag = torch.diagonal(c).add(-1).pow(2).sum()
        off_diag = off_diagonal(c).pow(2).sum()
        loss = on_diag + self.lambda_param * off_diag

        return loss, torch.tensor([0]), torch.tensor([0]), torch.tensor([0])



class DirectDLRLoss(nn.Module):
    """DirectDLR loss: temperature-scaled cosine similarity between two views,
    scored with cross-entropy against the matching-index target."""

    def __init__(self, temperature=0.5):
        """
        Args:
            temperature (float): Divisor applied to the cosine similarities.
        """
        super(DirectDLRLoss, self).__init__()
        self.temperature = temperature

    def forward(self, h1, h2):
        """
        Cross-entropy of view-1 rows against view-2 rows.

        Args:
            h1: Tensor of shape [B, k], from first view.
            h2: Tensor of shape [B, k], from second view.

        Returns:
            tuple: ``(loss, zeros, zeros, zeros)``. `loss` is the scalar loss
            tensor; the other three entries are constant ``tensor([0])``
            placeholders.
        """
        batch = h1.shape[0]

        # Unit-normalise the rows so the dot products are cosine similarities.
        view_a = F.normalize(h1, dim=1)
        view_b = F.normalize(h2, dim=1)

        # [B, B] similarity logits; entry (i, j) compares row i of view 1
        # with row j of view 2.
        logits = torch.mm(view_a, view_b.T) / self.temperature

        # The positive for row i is row i of the other view (the diagonal).
        targets = torch.arange(batch, device=h1.device)

        # Only the view 1 -> view 2 direction is scored.
        loss = F.cross_entropy(logits, targets)

        placeholders = tuple(torch.tensor([0]) for _ in range(3))
        return (loss,) + placeholders
